# File: predict.py
# This file is used to predict the d_uv (dissimilarity) of each edge in the graph.

import os
import networkx as nx
import gurobipy as gp
import numpy as np
from sklearn.manifold import SpectralEmbedding


def generate_ground_truth(data_dir, num_of_nodes, probability_list):
    '''
    Write the ground truth (approximately optimal solution) for every synthetic graph.

    For each ``num`` in ``num_of_nodes`` and each ``prob`` in ``probability_list``
    the file ``<data_dir>/nodes_<num>/prob_<prob>/ground_truth.txt`` is (re)written
    with one line ``"i j label"`` per node pair ``i < j``.  ``label`` is 1 when
    ``i`` lies in the first half of the node range and ``j`` in the second half
    (``i < num / 2 <= j``), and 0 otherwise.
    (Presumably the graphs are two equal-sized planted blocks -- confirm against
    the graph generator.)

    Args:
        data_dir: root directory of the synthetic data set.
        num_of_nodes: iterable of graph sizes (ints).
        probability_list: iterable of probabilities; only used to build the path.

    Missing output directories are created instead of raising FileNotFoundError.
    '''
    for num in num_of_nodes:
        half = num / 2
        # The labels depend only on `num`, so build the lines once and reuse
        # them for every probability.
        lines = [
            f"{i} {j} {1 if i < half <= j else 0}\n"
            for i in range(num)
            for j in range(i + 1, num)
        ]
        for prob in probability_list:
            ground_truth_file = os.path.join(data_dir, f'nodes_{num}/prob_{prob}/ground_truth.txt')
            os.makedirs(os.path.dirname(ground_truth_file), exist_ok=True)
            with open(ground_truth_file, 'w') as f:
                f.writelines(lines)


def get_OPT_from_ILP(input_graph_file, opt_file):
    '''
    Solve the clustering ILP ("CC_ILP_solver") for the input graph and save the solution.

    One binary variable x(i,j) is created per node pair i < j.  The objective
    minimises the sum of x(i,j) over pairs that are edges plus the sum of
    1 - x(i,j) over pairs that are not edges.  The solution is written to
    ``opt_file`` as one line ``"i j value"`` per pair, with value 0.0 or 1.0.

    Args:
        input_graph_file: edge list readable by ``nx.read_edgelist`` with int nodes.
        opt_file: path of the solution file to write.

    Raises:
        RuntimeError: if Gurobi finishes without any feasible solution.

    Licensing: the optional Gurobi WLS credentials are read by this script from
    the environment variables ``GRB_WLSACCESSID``, ``GRB_WLSSECRET`` and
    ``GRB_LICENSEID`` (an integer).  When none is set, Gurobi's default
    licence lookup is used.

    Assumes the nodes are labelled 0 .. n-1 (the loops index pairs by
    ``range(graph.number_of_nodes())``) -- TODO confirm for every data set.
    '''
    # Load the graph
    graph = nx.read_edgelist(input_graph_file, create_using=nx.Graph(), nodetype=int)

    # A Gurobi license is needed here.  Credentials come from the environment
    # rather than from literals in the source file.
    options = {}
    for key in ("WLSACCESSID", "WLSSECRET"):
        value = os.environ.get(f"GRB_{key}")
        if value:
            options[key] = value
    license_id = os.environ.get("GRB_LICENSEID")
    if license_id:
        options["LICENSEID"] = int(license_id)

    vertices_num = graph.number_of_nodes()
    pairs = [(i, j) for i in range(vertices_num) for j in range(i + 1, vertices_num)]

    # Build the model inside the licensed environment so that both are
    # disposed of when the block exits.
    with gp.Env(params=options or None) as env, \
            gp.Model("CC_ILP_solver", env=env) as model:
        # Add variables
        x = {}
        for i, j in pairs:
            x[i, j] = model.addVar(vtype=gp.GRB.BINARY, name="x(%s,%s)" % (i, j))

        # Add constraints
        # NOTE(review): a triangle inequality is only added for triples in which
        # exactly two of the three pairs are edges; other triples are left
        # unconstrained, so the solution is not guaranteed to be transitive.
        # Confirm that this relaxation is intended.
        for i in range(vertices_num):
            for j in range(i + 1, vertices_num):
                for k in range(j + 1, vertices_num):
                    if graph.has_edge(i, j) and graph.has_edge(j, k) and not graph.has_edge(i, k):
                        model.addConstr(x[i, j] + x[j, k] >= x[i, k], "c(%s,%s,%s)" % (i, j, k))
                    elif graph.has_edge(i, j) and graph.has_edge(i, k) and not graph.has_edge(j, k):
                        model.addConstr(x[i, j] + x[i, k] >= x[j, k], "c(%s,%s,%s)" % (i, j, k))
                    elif graph.has_edge(i, k) and graph.has_edge(j, k) and not graph.has_edge(i, j):
                        model.addConstr(x[i, k] + x[j, k] >= x[i, j], "c(%s,%s,%s)" % (i, j, k))

        # Set the objective function
        obj = gp.quicksum(x[i, j] for i, j in pairs if graph.has_edge(i, j))
        obj += gp.quicksum(1 - x[i, j] for i, j in pairs if not graph.has_edge(i, j))
        model.setObjective(obj, sense=gp.GRB.MINIMIZE)

        # Solve the model
        model.optimize()

        if model.SolCount == 0:
            raise RuntimeError(
                f"Gurobi found no feasible solution for {input_graph_file} "
                f"(status {model.Status})"
            )

        # Read the values while the model is still alive.  Binary variables may
        # come back as e.g. 0.9999999 or -0.0, so round to a clean 0.0 / 1.0.
        solution = [(i, j, float(round(x[i, j].X))) for i, j in pairs]

    # Write the solution
    with open(opt_file, "w") as f:
        f.writelines(f"{i} {j} {value}\n" for i, j, value in solution)


def generate_perturbed_prediction(OPT_file, perturbation, prediction_file):
    '''
    Generate perturbed predictions from a 0/1 solution file.

    Every non-blank line ``"u v value"`` of ``OPT_file`` becomes a line
    ``"u v p"`` in ``prediction_file``, where ``p`` is ``1 - perturbation``
    when the value is (numerically) 1 and ``perturbation`` otherwise.

    Args:
        OPT_file: path of the solution file to read.
        perturbation: amount by which each 0/1 value is moved towards the other.
        prediction_file: path of the prediction file to write (overwritten).

    Raises:
        ValueError: if a non-blank line does not consist of two ints and a float.
    '''
    with open(OPT_file, "r") as file_read, \
            open(prediction_file, "w") as file_write:

        for line in file_read:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            u, v, opt_val = fields
            u, v, opt_val = int(u), int(v), float(opt_val)
            # Solver output for a binary variable can be 0.9999999 or 1e-9, so
            # compare against a threshold instead of testing `== 1` exactly.
            predicted = 1 - perturbation if opt_val > 0.5 else perturbation
            file_write.write(f"{u} {v} {predicted}\n")


def generate_spectral_embedding_prediction(input_graph_file, cluster_num, prediction_file):
    '''
    Generate dissimilarity predictions from a spectral embedding of the graph.

    The adjacency matrix of the graph is embedded into ``cluster_num``
    dimensions with scikit-learn's ``SpectralEmbedding``.  For every node pair
    ``u < v`` (in sorted node order) a line ``"u v d"`` is written to
    ``prediction_file`` with ``d = 1 - |cos(e_u, e_v)|``, clipped to [0, 1].

    Args:
        input_graph_file: edge list readable by ``nx.read_edgelist`` with int nodes.
        cluster_num: number of embedding dimensions (``n_components``).
        prediction_file: path of the prediction file to write (overwritten).

    Notes:
        * Matrix rows are ordered by sorted node id, and the node ids themselves
          are written, so each line refers to the nodes whose embeddings were
          compared.  For ids 0 .. n-1 this is the plain index.
        * A node whose embedding is the zero vector has no defined cosine; its
          pairs are written with dissimilarity 1 instead of ``nan``.
    '''
    graph = nx.read_edgelist(input_graph_file, nodetype=int, create_using=nx.Graph())

    # Fix the row order explicitly: without `nodelist`, rows follow the order in
    # which nodes first appear in the edge list, not their ids.
    nodelist = sorted(graph.nodes())
    adjacency = nx.to_numpy_array(graph, nodelist=nodelist)

    # compute spectral embedding
    # NOTE(review): per the scikit-learn docs, affinity='nearest_neighbors'
    # builds a k-NN graph from the rows of the input, whereas 'precomputed'
    # would use the adjacency matrix itself as the affinity -- confirm that
    # 'nearest_neighbors' is the intended choice.
    se = SpectralEmbedding(n_components=cluster_num,
                           affinity='nearest_neighbors',
                           random_state=0)
    embeddings = se.fit_transform(adjacency)

    # Normalise each embedding once so that a dot product is the cosine.
    norms = np.linalg.norm(embeddings, axis=1)
    norms[norms == 0] = 1.0  # zero vectors stay zero -> cosine 0 -> dissimilarity 1
    unit = embeddings / norms[:, np.newaxis]

    with open(prediction_file, 'w') as f:
        for u, node_u in enumerate(nodelist):
            # |cosine| between node u and every later node, in one product.
            similarities = np.abs(unit[u + 1:] @ unit[u])
            # Rounding can push |cosine| slightly above 1; keep d within [0, 1].
            dissimilarities = np.clip(1.0 - similarities, 0.0, 1.0).tolist()
            f.writelines(
                f'{node_u} {node_v} {d}\n'
                for node_v, d in zip(nodelist[u + 1:], dissimilarities)
            )


# Script entry point: regenerates prediction files for the experiment sections
# that are currently enabled.  A section wrapped in ''' ... ''' is a bare string
# literal, i.e. disabled code; remove the surrounding quotes to enable it.
if __name__ == "__main__":

    ### synthetic datasets ###
    # Graph sizes, probabilities and perturbation levels; sizes and probabilities
    # are only used here to build the ./SBM/nodes_<n>/prob_<p>/ paths.
    num_of_nodes=[100]
    probability_list=[0.9, 0.8, 0.7]
    # (sic) the name is spelled "pertubation_list" throughout this block,
    # including inside the disabled sections.
    pertubation_list = [0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28]
    
    # using ground truth to generate predictions
    # (disabled: writes ground_truth.txt, then one prediction_gt_<perturbation>.txt
    # per perturbation level)
    '''
    generate_ground_truth('./SBM', num_of_nodes, probability_list)
    
    for num in num_of_nodes:
        for prob in probability_list:
            OPT_file = os.path.join('./SBM', f'nodes_{num}/prob_{prob}/ground_truth.txt')
            for perturbation in pertubation_list:
                prediction_file = os.path.join('./SBM', f'nodes_{num}/prob_{prob}/prediction_gt_{perturbation}.txt')
                generate_perturbed_prediction(OPT_file, perturbation, prediction_file)
    '''

    # using ILP to generate predictions
    # (enabled: the ''' markers around this section are commented out)
    # '''
    for num in num_of_nodes:
        for prob in probability_list:
            input_graph_file = os.path.join('./SBM', f'nodes_{num}/prob_{prob}/edges.txt')
            opt_file = os.path.join('./SBM', f'nodes_{num}/prob_{prob}/opt_solution.txt')
            # The ILP solve is commented out, so opt_solution.txt must already
            # exist on disk for the loop below to read it.
            # get_OPT_from_ILP(input_graph_file, opt_file)

            # One prediction_opt_<perturbation>.txt per perturbation level.
            for perturbation in pertubation_list:
                prediction_file = os.path.join('./SBM', f'nodes_{num}/prob_{prob}/prediction_opt_{perturbation}.txt')
                generate_perturbed_prediction(opt_file, perturbation, prediction_file)
    # '''

    ### first type datasets ###
    # datasets = ['facebook0', 'facebook414', 'facebook3980']
    # `datasets` is only referenced inside the disabled section below.
    datasets = ['facebook3980']
    # (disabled: perturbed predictions from ./facebook/<dataset>/OPT_sol.txt,
    # using its own smaller perturbation levels)
    '''
    pertubation_list = [0.002, 0.004, 0.006, 0.008, 0.01, 0.012, 0.014, 0.016, 0.018, 0.02]
    for dataset in datasets:
        graph_file = os.path.join('./facebook', f'{dataset}/edges.txt')
        opt_file = os.path.join('./facebook', f'{dataset}/OPT_sol.txt')
        # get_OPT_from_ILP(graph_file, opt_file)

        for perturbation in pertubation_list:
            prediction_file = os.path.join('./facebook', f'{dataset}/prediction_opt_{perturbation}.txt')
            generate_perturbed_prediction(opt_file, perturbation, prediction_file)
    '''

    ### second type datasets ###
    # (both sections disabled: spectral-embedding predictions, one
    # prediction_se_<cluster_num>.txt per embedding dimension)
    # email-Eu-core.txt
    '''
    cluster_num_list = [600, 650, 700, 750, 800, 850, 900, 950, 1000]
    for cluster_num in cluster_num_list:
        prediction_file = os.path.join('./email-Eu-core', f"prediction_se_{cluster_num}.txt")
        generate_spectral_embedding_prediction('./email-Eu-core/email-Eu-core.txt', cluster_num, prediction_file)
    '''
    # lastfm_asia.txt
    '''
    cluster_num_list = [5500, 6000, 6500]
    for cluster_num in cluster_num_list:
        prediction_file = os.path.join('./lastfm_asia', f"prediction_se_{cluster_num}.txt")
        generate_spectral_embedding_prediction('./lastfm_asia/edges.txt', cluster_num, prediction_file)
    '''

    # No-op; the block already contains executable statements above.
    pass